import torch
import torch.nn as nn
import torch.nn.functional as F
# Define torch.pi as 2 * acos(0), evaluated on a default-dtype tensor and
# converted to a Python float. This assignment overwrites any `pi` attribute
# the installed torch module may already provide.
torch.pi = torch.acos(torch.zeros(1)).item() * 2

def get_regularizer(reg):
    """
    Look up the regularization loss function selected by ``reg``.

    Args:
        reg : selector; 'ALP' maps to ``mse_loss`` and 'KL' maps to
              ``kldiv_loss``

    Returns:
        The selected loss function (the callable itself, not a loss value).

    Raises:
        NotImplementedError : if ``reg`` is neither 'ALP' nor 'KL'
    """
    if reg == 'ALP':
        return mse_loss
    if reg == 'KL':
        return kldiv_loss
    # Anything else is an unsupported regularizer.
    raise NotImplementedError
             
def mse_loss(x1,x2):
    """
    Mean squared error between ``x1`` and ``x2``, scaled by 0.25.

    Args:
        x1, x2 : tensors accepted by ``F.mse_loss``

    Returns:
        Scalar tensor equal to ``0.25 * mean((x1 - x2) ** 2)``.
    """
    mean_sq_err = F.mse_loss(x1, x2, reduction='mean')
    return 0.25 * mean_sq_err
def kldiv_loss(x1,x2):
    """
    KL divergence between the softmax distributions of two batches of logits.

    Both inputs are turned into log-probabilities with ``log_softmax`` over
    ``dim=1``. Following the ``F.kl_div(input, target)`` convention, the value
    returned is ``KL(softmax(x2) || softmax(x1))``: the pointwise terms are
    summed and divided by the batch size (``reduction='batchmean'``).

    Args:
        x1 : logits whose softmax is the approximating distribution
        x2 : logits whose softmax is the reference (target) distribution

    Returns:
        Scalar tensor holding the batch-averaged KL divergence.
    """
    # log_softmax is already numerically stable, so no epsilon is needed; an
    # offset added in the log domain would only leave both distributions
    # slightly unnormalized.
    log_p1 = F.log_softmax(x1, dim=1)
    log_p2 = F.log_softmax(x2, dim=1)
    return F.kl_div(log_p1, log_p2, reduction='batchmean', log_target=True)

class SpatialAttacker(nn.Module):
    """
    Attacker class, used to make adversarial examples using rotations.

    For every sample a set of candidate rotation angles is tried and the one
    with the highest per-sample loss (the "worst of k") is kept.
    """
    def __init__(self,method,args,adv=True):
        """
        Initialize the Attacker

        Args:
            method : how candidate angles are chosen; 'grid' (evenly spaced
                     angles shared by all samples) or 'random' (independent
                     uniform draws per sample)
            args   : configuration object; ``adv_criterion`` and
                     ``rotation_range`` (degrees) are always read,
                     ``granularity`` is read for 'grid' and ``k`` for 'random'
            adv    : if True generate adversarial examples, otherwise
                     ``forward`` returns its input unchanged

        Raises:
            NotImplementedError : for 'fo' (first order attack, not
                                  implemented yet) and any unknown method
        """
        super(SpatialAttacker, self).__init__()
        self.method = method
        self.adv = adv
        self.criterion = args.adv_criterion
        # Attack parameters: angles are searched in [-limits, limits] degrees.
        self.limits = args.rotation_range

        # Attack method parameters
        if self.method == 'grid':
            self.granularity = args.granularity
        elif self.method == 'random':
            self.random_tries = args.k
        elif self.method == 'fo':
            # To Do implement first order attack
            raise NotImplementedError("first order ('fo') attack is not implemented")
        else:
            raise NotImplementedError("unknown attack method: {!r}".format(method))

    def getrotmats(self,angles):
        """
        Build 2x3 affine rotation matrices ``[[cos, -sin, 0], [sin, cos, 0]]``.

        Args:
            angles : 1-D tensor of b angles in radians

        Returns:
            Tensor of shape (b, 2, 3), one matrix per angle.
        """
        cos_a = torch.cos(angles)
        sin_a = torch.sin(angles)
        no_shift = torch.zeros_like(cos_a)
        top_row = torch.stack([cos_a, -sin_a, no_shift], dim=-1)
        bottom_row = torch.stack([sin_a, cos_a, no_shift], dim=-1)
        return torch.stack([top_row, bottom_row], dim=-2)

    def _rotate(self, x, angles):
        """Rotate each image in ``x`` by the matching entry of ``angles`` (radians)."""
        theta = self.getrotmats(angles).to(device=x.device, dtype=x.dtype)
        # align_corners=False is the library default; stating it explicitly
        # avoids the "default behavior has changed" warnings.
        grid = F.affine_grid(theta, x.shape, align_corners=False)
        return F.grid_sample(x, grid, mode='bilinear', padding_mode='zeros',
                             align_corners=False)

    @staticmethod
    def _model_device(model, fallback):
        """Device of the model's first parameter (or buffer), else ``fallback``."""
        probe = next(model.parameters(), None)
        if probe is None:
            probe = next(model.buffers(), None)
        return fallback if probe is None else probe.device

    def compute_worst_of_k(self,rad,x,labels,num_angles,criterion, model):
        """
        Pick, for every sample, the candidate rotation with the highest loss.

        Args:
            rad        : (b, num_angles) candidate angles in radians
            x          : (b, c, m, n) input batch
            labels     : targets passed to ``criterion``
            num_angles : number of candidate columns of ``rad`` to evaluate
            criterion  : callable returning one loss value per sample
            model      : network under attack

        Returns:
            (rotated, worst_of_k): ``x`` rotated by the selected angles (on
            ``x.device``) and the selected angles in radians as a CPU tensor
            of shape (b,).
        """
        device = x.device
        rad = rad.to(device)
        labels = labels.to(device)
        batch = x.shape[0]
        loss_track = torch.zeros(batch, num_angles, device=device)

        # The search runs in eval mode; remember every submodule's mode so the
        # caller's train/eval configuration is restored afterwards.
        saved_modes = [(module, module.training) for module in model.modules()]
        model.eval()
        try:
            with torch.no_grad():
                for i in range(num_angles):
                    candidate = self._rotate(x, rad[:, i])
                    loss_track[:, i] = criterion(model(candidate), labels)
                _, worst_idx = loss_track.max(dim=1)
                # Select each row's worst angle in a single indexing op.
                rows = torch.arange(batch, device=device)
                worst_of_k = rad[rows, worst_idx]
                rotated = self._rotate(x, worst_of_k)
        finally:
            for module, mode in saved_modes:
                module.training = mode
        return rotated, worst_of_k.cpu()

    def forward(self, x, labels, model):
        """
        Find adversarial examples using rotation, return transformed examples, worst case rotation

        Args:
            x      : (b, c, m, n) input batch
            labels : targets for the attack criterion
            model  : network under attack

        Returns:
            (examples, angles). With ``adv=False`` this is ``x`` itself and a
            zero tensor of shape (b,) on ``x.device``. Otherwise it is the
            worst-case rotated batch (on the model's device) and the chosen
            angles in radians (CPU tensor of shape (b,)).

        Raises:
            NotImplementedError : if the attack criterion is not 'xent' or
                                  the method is not 'random' / 'grid'
        """
        b = x.shape[0]
        if not self.adv:
            return x, torch.zeros(b, device=x.device)

        # Initialize attacker criterion
        if self.criterion != 'xent':
            raise NotImplementedError(
                "unsupported attack criterion: {!r}".format(self.criterion))
        criterion = nn.CrossEntropyLoss(reduction='none')

        # Run on whatever device the model lives on instead of assuming CUDA.
        device = self._model_device(model, x.device)
        orig_input = x.detach().to(device)

        if self.method == 'random':
            # Draw on the CPU generator, then move the angles to the device.
            degrees = torch.empty(b, self.random_tries).uniform_(-self.limits,
                                                                 self.limits)
            rad = (torch.pi/180.)*degrees.float().to(device)
            num_angles = self.random_tries
        elif self.method == 'grid':
            # start + i*step yields exactly `granularity` angles; arange with
            # a float step may round to one more or one fewer.
            step = 2*self.limits/self.granularity
            degrees = -self.limits + step*torch.arange(
                self.granularity, dtype=torch.float32, device=device)
            rad = (torch.pi*degrees/180.).repeat(b,1)
            num_angles = self.granularity
        else:
            raise NotImplementedError(
                "unknown attack method: {!r}".format(self.method))

        return self.compute_worst_of_k(rad, orig_input, labels, num_angles,
                                       criterion, model)
            

